import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score
from scipy import stats
from sklearn.utils import resample
import random
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
pd.set_option('future.no_silent_downcasting', True)
pd.options.mode.chained_assignment = None

# Seed the global RNGs so bootstrapped standard errors are reproducible between
# runs (sklearn's `resample` draws from NumPy's global generator when no
# random_state is given).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Data Import
# The out-of-domain benchmark (`ood`) is loaded in the Table 19 section, where it is used.
loss = pd.read_csv('./data/wandb_loss.csv')

# Functions
def metrics(df, preds, group_by=None, true_col='entailment'):
    """
    Calculate MCC, Accuracy and support-weighted F1 for each prediction column.

    Args:
        df (pd.DataFrame): The input DataFrame containing true and predicted labels.
        preds (list): List of column names containing model predictions.
        group_by (str, optional): 'dataset' or 'task' to compute metrics separately
            for each group of that column. Any other value (including None) computes
            metrics over the whole DataFrame. Defaults to None.
        true_col (str, optional): Column holding the gold labels. Defaults to 'entailment'.

    Returns:
        pd.DataFrame: Columns 'MCC', 'Accuracy', 'F1'. Indexed by 'Column' (the
        prediction column name), or by ['Column', group_by.capitalize()] when grouped.
    """
    def get_metrics(y_true, y_pred):
        return {
            'MCC': matthews_corrcoef(y_true, y_pred),
            'Accuracy': accuracy_score(y_true, y_pred),
            'F1': f1_score(y_true, y_pred, average='weighted'),
        }

    grouped = group_by in ['dataset', 'task']
    results = []

    if grouped:
        group_label = group_by.capitalize()
        # Split once and reuse for every prediction column.
        groups = list(df.groupby(group_by))
        for col in preds:
            for group_name, group in groups:
                row = get_metrics(group[true_col], group[col])
                row['Column'] = col
                row[group_label] = group_name
                results.append(row)
    else:
        for col in preds:
            row = get_metrics(df[true_col], df[col])
            row['Column'] = col
            results.append(row)

    results_df = pd.DataFrame(results)

    if grouped:
        return results_df.set_index(['Column', group_label])
    return results_df.set_index('Column')

def bootstrapped_errors(y_true, y_pred, n_bootstrap=1000, random_state=None):
    """
    Calculate bootstrapped standard errors for MCC, Accuracy, and weighted F1.

    Each bootstrap draw resamples (y_true, y_pred) pairs with replacement, and the
    standard error is the (population, ddof=0) standard deviation of the metric
    across draws.

    Args:
        y_true (array-like): True labels.
        y_pred (array-like): Predicted labels.
        n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 1000.
        random_state (None, int or np.random.RandomState, optional): Controls the
            resampling. None (default) uses NumPy's global RNG, as before; an int
            seeds a private generator so the result is reproducible.

    Returns:
        dict: Standard errors under keys 'MCC_SE', 'Accuracy_SE', 'F1_SE'.
    """
    # Build ONE generator for all draws; passing the same int seed to every
    # `resample` call would make all bootstrap samples identical.
    if random_state is None or isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    mcc_scores = []
    accuracy_scores = []
    f1_scores = []

    for _ in range(n_bootstrap):
        # Resample (true, predicted) pairs together, with replacement
        y_true_resampled, y_pred_resampled = resample(y_true, y_pred, random_state=rng)

        mcc_scores.append(matthews_corrcoef(y_true_resampled, y_pred_resampled))
        accuracy_scores.append(accuracy_score(y_true_resampled, y_pred_resampled))
        f1_scores.append(f1_score(y_true_resampled, y_pred_resampled, average='weighted'))

    return {
        'MCC_SE': np.std(mcc_scores),
        'Accuracy_SE': np.std(accuracy_scores),
        'F1_SE': np.std(f1_scores),
    }

def metrics_with_errors(df, preds, n_bootstrap=1000, group_by=None, z=1.0):
    """
    Calculate metrics and bootstrapped standard errors for predictions, optionally grouped.

    Args:
        df (pd.DataFrame): The input DataFrame containing true labels (column
            'entailment') and predicted labels.
        preds (list): List of column names containing model predictions.
        n_bootstrap (int, optional): Number of bootstrap samples. Defaults to 1000.
        group_by (str, optional): Column name to group by ('dataset' or 'task').
            Any other value computes ungrouped metrics. Defaults to None.
        z (float, optional): Multiplier applied to each standard error when forming
            the *_Lower / *_Upper bounds. Defaults to 1.0, i.e. metric ± 1 SE
            (roughly a 68% interval if the bootstrap distribution is near-normal);
            use 1.96 for an approximate 95% interval.

    Returns:
        pd.DataFrame: Indexed like `metrics(...)`, with columns in this order:
        MCC, Accuracy, F1, MCC_SE, Accuracy_SE, F1_SE, MCC_Lower, MCC_Upper,
        Accuracy_Lower, Accuracy_Upper, F1_Lower, F1_Upper.
    """
    grouped = group_by in ['dataset', 'task']

    # Step 1: point estimates for each model (and group)
    metrics_df = metrics(df, preds, group_by=group_by)

    # Step 2: bootstrapped standard errors over exactly the same rows as Step 1
    errors = []
    if grouped:
        group_label = group_by.capitalize()
        groups = list(df.groupby(group_by))
        for col in preds:
            for group_name, group in groups:
                row = bootstrapped_errors(group['entailment'], group[col], n_bootstrap=n_bootstrap)
                row['Column'] = col
                row[group_label] = group_name
                errors.append(row)
        errors_df = pd.DataFrame(errors).set_index(['Column', group_label])
    else:
        for col in preds:
            row = bootstrapped_errors(df['entailment'], df[col], n_bootstrap=n_bootstrap)
            row['Column'] = col
            errors.append(row)
        errors_df = pd.DataFrame(errors).set_index('Column')

    # Step 3: align metrics and errors on their shared index
    combined_df = metrics_df.merge(errors_df, left_index=True, right_index=True)

    # Step 4: symmetric bounds at metric ± z * SE (column order kept stable;
    # downstream code selects columns by position)
    for name in ('MCC', 'Accuracy', 'F1'):
        margin = z * combined_df[f'{name}_SE']
        combined_df[f'{name}_Lower'] = combined_df[name] - margin
        combined_df[f'{name}_Upper'] = combined_df[name] + margin

    return combined_df

def format_f1_table(df, index_col='Model', task_col='Task', value_col='F1', se_col='F1_SE', task_order=None, model_order=None):
    """
    Reshape a long metrics table into a wide one with value/SE columns interleaved per task.

    Parameters:
    df : pd.DataFrame
        Long-format dataframe holding index_col, task_col, value_col and se_col.
    index_col : str
        Column that becomes the row identifier (default 'Model').
    task_col : str
        Column whose values become the column groups (default 'Task').
    value_col : str
        Column with the headline metric (default 'F1').
    se_col : str
        Column with that metric's standard error (default 'F1_SE').
    task_order : list[str] or None
        Order of the task column groups. Tasks missing from the data get all-NaN
        columns; tasks not listed are dropped. If None, the order is taken from
        the pivoted columns.
    model_order : list[str] or None
        Order of the rows. If given, rows are reindexed to exactly this list
        (unlisted models dropped, missing ones filled with NaN).

    Returns:
    pd.DataFrame
        Columns [index_col, Task1, Task1_SE, Task2, Task2_SE, ...]; every
        non-index value is multiplied by 100 and rounded to 2 decimals.
    """
    subset = df[[index_col, task_col, value_col, se_col]]
    # pivot_table (not pivot) so duplicate (index, task) pairs are averaged instead of raising
    wide = subset.pivot_table(
        index=index_col,
        columns=task_col,
        values=[value_col, se_col],
        aggfunc='mean',
    )

    # (metric, task) -> 'task' for the value, 'task_SE' for the standard error
    renamed = []
    for metric_name, task_name in wide.columns:
        if metric_name == se_col:
            renamed.append(f"{task_name}_SE")
        elif metric_name == value_col:
            renamed.append(f"{task_name}")
        else:
            renamed.append(f"{metric_name}_{task_name}")
    wide.columns = renamed

    table = wide.reset_index()

    if task_order is None:
        # Strip the '_SE' suffix and keep first-seen order without duplicates
        bases = (
            name[:-3] if name.endswith('_SE') else name
            for name in table.columns
            if name != index_col
        )
        task_order = list(dict.fromkeys(bases))

    ordered = [index_col]
    for name in task_order:
        ordered.extend([name, f"{name}_SE"])

    # Requested tasks with no data still get (empty) columns
    for name in ordered:
        if name not in table.columns:
            table[name] = pd.NA

    table = table[ordered]

    if model_order is not None:
        table = table.set_index(index_col).reindex(model_order).reset_index()

    value_cols = table.columns.drop(index_col)
    as_numbers = table[value_cols].apply(pd.to_numeric, errors='coerce')
    table[value_cols] = as_numbers.mul(100).round(2)

    return table

def compute_stability(labels_df, label_cols=None, group_by=None):
    """
    Compute per-group label-change rates across alternative prompts.

    For every row, the labels in `label_cols` are compared pairwise
    (C(len(label_cols), 2) pairs). The 'Stability' score of a group is the number
    of differing pairs divided by the total number of pairs in that group. 0 means
    every row got the same label under every prompt; higher values mean more
    label changes. NaN labels are treated as equal to each other.

    Parameters
    - labels_df: pd.DataFrame produced by `stability_benchmark_all_labels`.
                 Must contain the label columns and any grouping columns you want to use.
    - label_cols: list of str, names of the label columns to use. If None,
                  defaults to ['hypothesis_label','alt1_label','alt2_label','alt3_label'].
    - group_by: str or list/tuple of str. Column name(s) to group by. If None, defaults
                to ['Checkpoint']. An empty list computes a single overall score.

    Returns
    - pd.DataFrame with columns = group_by + ['Stability'], one row per group, sorted
      by the group_by columns. If labels_df is empty, the frame has those columns
      and no rows. A group's score is NaN when it has no pairs to compare (fewer
      than two label columns).

    Raises
    - TypeError: if group_by is not None, a string, or a list/tuple.
    - ValueError: if any group_by column is missing from labels_df.
    """
    if label_cols is None:
        label_cols = ['hypothesis_label', 'alt1_label', 'alt2_label', 'alt3_label']

    # normalize group_by to a list
    if group_by is None:
        group_by = ['Checkpoint']
    elif isinstance(group_by, str):
        group_by = [group_by]
    elif isinstance(group_by, (list, tuple)):
        group_by = list(group_by)
    else:
        raise TypeError("group_by must be None, a string, or a list/tuple of strings")

    # verify group_by columns exist
    missing = [c for c in group_by if c not in labels_df.columns]
    if missing:
        raise ValueError(f"The following group_by columns are not present in labels_df: {missing}")

    n_prompts = len(label_cols)
    pairs_per_row = (n_prompts * (n_prompts - 1)) // 2

    def differing_pairs_for_row(row):
        """Number of the C(n_prompts, 2) label pairs in `row` that disagree."""
        counts = row.value_counts(dropna=False)
        same_label_pairs = sum((c * (c - 1)) // 2 for c in counts)
        return pairs_per_row - same_label_pairs

    def stability_of(frame):
        """Pooled fraction of differing label pairs over all rows of `frame`."""
        total_comparisons = pairs_per_row * len(frame)
        if total_comparisons == 0:
            # empty frame or fewer than two label columns: nothing to compare
            return float('nan')
        diffs = frame[label_cols].apply(differing_pairs_for_row, axis=1)
        return int(diffs.sum()) / total_comparisons

    if not group_by:
        return pd.DataFrame({'Stability': [stability_of(labels_df)]})

    results = []
    for key, group in labels_df.groupby(group_by, sort=True):
        # pandas may yield a scalar key for single-column grouping; normalize to a tuple
        if not isinstance(key, tuple):
            key = (key,)
        row = dict(zip(group_by, key))
        row['Stability'] = stability_of(group)
        results.append(row)

    # explicit columns so an empty input still yields a well-formed frame
    stability_df = pd.DataFrame(results, columns=[*group_by, 'Stability'])
    return stability_df.sort_values(group_by, ignore_index=True)

def plot_stability(data, title, save_path=None, ylim=(0, 0.07)):
    """
    Plot stability (label-change rate) against training checkpoint.

    Parameters:
    - data: DataFrame or list of DataFrames containing "Checkpoint", "Stability", and
            optionally "Model" columns. A list is concatenated first. When "Model" is
            present, each model gets its own coloured line; otherwise a single line is drawn.
    - title: str, the title of the plot.
    - save_path: str, optional. File path to save the plot (300 dpi). If given, the
                 figure is saved and then closed; if None, the plot is shown instead.
    - ylim: (low, high) tuple for the y-axis, or None to let matplotlib autoscale.
            Defaults to (0, 0.07).

    Returns:
    - The matplotlib Axes that was drawn on.
    """
    sns.set_theme(style="whitegrid", palette="colorblind")

    # Check if data is a list of DataFrames and concatenate if needed
    if isinstance(data, list):
        data = pd.concat(data, ignore_index=True)

    # "Model" is optional: only colour by it when present (seaborn errors on a missing hue column)
    hue = "Model" if "Model" in data.columns else None

    fig = plt.figure(figsize=(10, 6))
    ax = sns.lineplot(data=data, x="Checkpoint", y="Stability", hue=hue, marker="o")

    plt.title(title, fontsize=18, fontweight='bold')
    plt.xlabel("Training Step", fontsize=16, fontweight='bold')
    plt.ylabel("P(Label Change|Hypothesis Change)", fontsize=16, fontweight='bold')

    # Remove top and right spines, and the grid lines set by the whitegrid theme
    sns.despine(top=True, right=True)
    plt.grid(False)

    if ylim is not None:
        ax.set_ylim(*ylim)

    # Legend (only exists when hue is used): drop its title
    legend = ax.get_legend()
    if legend:
        legend.set_title(None)

    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
        # release the figure so repeated calls don't accumulate open figures
        plt.close(fig)
    else:
        plt.show()

    return ax

############
## Figure 9
############
sns.set_palette("colorblind")

# (subplot index, panel title, training-loss column, eval-loss column).
# Indices refer to the flattened 2x2 grid: 0 top-left, 1 top-right,
# 2 bottom-left, 3 bottom-right.
plots = [
    (0, "Large",     "large_train",    "large_eval"),
    (2, "Base",      "base_train",     "base_eval"),
    (3, "Base (Modern BERT)",   "mb_base_train",  "mb_base_eval"),
    (1, "Large (Modern BERT)",  "mb_large_train", "mb_large_eval"),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True, sharey=False)
axes = axes.flatten()

for idx, model_name, train_col, eval_col in plots:
    ax = axes[idx]
    top_row = idx in (0, 1)
    right_col = idx in (1, 3)

    # training loss (solid)
    sns.lineplot(
        x="Step", y=train_col,
        data=loss, ax=ax, linewidth=2, label="Train", legend=False
    )
    # eval loss (dashed)
    sns.lineplot(
        x="Step", y=eval_col,
        data=loss, ax=ax, linewidth=2,
        linestyle="--", label="Eval", legend=False
    )
    ax.set_title(model_name, fontweight="bold", fontsize=16)
    # per-axis labels are replaced by the shared figure-level labels below
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.grid(axis="y", linestyle="--", alpha=0.7)

    # both panels in a row use the same y-range
    ax.set_ylim(0, 2 if top_row else 0.75)

    if top_row:
        # hide x ticks and labels on the top row (x is shared with the bottom row)
        ax.tick_params(axis="x", which="both", length=0, labelbottom=False)
    if right_col:
        # hide y ticks and labels on the right column. Axes.tick_params ignores
        # `labelbottom` for the y-axis, so `labelleft` is needed to hide the labels.
        ax.tick_params(axis="y", which="both", length=0, labelleft=False)

    # increase tick-label font size
    ax.tick_params(axis="both", which="major", labelsize=14)

fig.text(0.52, -0.02, "Step", ha="center", fontsize=16, fontweight='bold')
fig.text(-0.02, 0.48, "Loss", va="center", rotation="vertical", fontsize=16, fontweight='bold')

# one legend for all panels (every panel has the same Train/Eval lines)
handles, labels = axes[-1].get_legend_handles_labels()
legend = fig.legend(
    handles, labels,
    loc=(0.42, .91),
    ncol=2,
    frameon=False,
    fontsize=14,
    title="Loss Type",
    title_fontsize=16
)
legend.get_title().set_fontweight("bold")
fig.tight_layout(rect=[0, 0, 1, 0.95])
fig.savefig(r'./figures/figure_9.png', dpi=300, bbox_inches='tight')
plt.close(fig)

############
## Table 16
############
# Models to calculate metrics for
columns = ['base_nli',
           'large_nli',
           'base_debate',
           'large_debate',
           'base_modern',
           'large_modern',
           'llama']

# Compile labels into a single dataset, tagging each source with its task name
ukp = pd.read_csv('./data/ukp_stance.csv')
ukp['task'] = 'UKP Stance'
topic = pd.read_csv('./data/ukp_topic.csv')
topic['task'] = 'UKP Topic'
rand = pd.read_csv('./data/rand_terror.csv')
rand['task'] = 'RAND Event'
delib = pd.read_csv('./data/deliberative_politics.csv')
delib['task'] = 'Deliberative Politics'

# Each CSV brings its own 0..n-1 row labels; renumber so the combined frame has unique labels
unseen = pd.concat([ukp, topic, rand, delib], ignore_index=True)

# Calculate metrics (with bootstrapped standard errors) per task
unseen_metrics = metrics_with_errors(unseen, columns, group_by='task')

# Organize and clean: move the (Column, Task) index into regular columns
unseen_metrics.reset_index(inplace=True)
unseen_metrics.rename({'Column': 'Model'}, axis=1, inplace=True)

# Map prediction-column identifiers to display names; only the Model column holds them
unseen_metrics['Model'] = unseen_metrics['Model'].replace({
    'base_nli': 'DeBERTa Base',
    'large_nli': 'DeBERTa Large',
    'base_debate': 'DEBATE Base',
    'large_debate': 'DEBATE Large',
    'base_modern': 'DEBATE Base (MB)',
    'large_modern': 'DEBATE Large (MB)',
    'llama': 'Llama 3.1 8B',
})

# Create formatted table
task_order = ["UKP Stance", "UKP Topic", "RAND Event", "Deliberative Politics"]
model_order = [
    "DeBERTa Base",
    "DeBERTa Large",
    "DEBATE Base",
    "DEBATE Large",
    "DEBATE Base (MB)",
    "DEBATE Large (MB)",
    "Llama 3.1 8B"
]
unseen_formatted = format_f1_table(
    unseen_metrics,
    index_col='Model',
    task_col='Task',
    value_col='F1',
    se_col='F1_SE',
    task_order=task_order,
    model_order=model_order
)

# convert to string
unseen_string = unseen_formatted.to_string()

output_file = './tables/table_16.txt'
with open(output_file, 'w', encoding='utf-8') as f:
    f.write(unseen_string)

############
## Table 18
############
nli = pd.read_csv('./data/nli_bench.csv')

# The six encoder models evaluated on the NLI benchmarks (this list is also used by Table 19)
columns = [
    'base_nli',
    'large_nli',
    'base_debate',
    'large_debate',
    'base_modern',
    'large_modern',
]

# Overall scores are tagged with the pseudo-task 'overall' so they can sit next to per-task rows
overall = metrics_with_errors(nli, columns, group_by=None).reset_index()
overall['Task'] = 'overall'
task = metrics_with_errors(nli, columns, group_by='task').reset_index()

# Display names for both model identifiers and task identifiers
nli_display_names = {
    'base_nli': 'DeBERTa Base',
    'large_nli': 'DeBERTa Large',
    'base_debate': 'DEBATE Base',
    'large_debate': 'DEBATE Large',
    'base_modern': 'DEBATE Base (MB)',
    'large_modern': 'DEBATE Large (MB)',
    'anli_r1': 'ANLI',
    'mnli_m': 'MNLI',
    'wanli': 'WANLI',
    'overall': 'Overall',
}
combined = (
    pd.concat([overall, task])
    .rename(columns={'Column': 'Model'})
    .replace(nli_display_names)
)

# Wide table: one row per model, F1 / F1_SE pairs per task
model_order = ['DeBERTa Base', 'DeBERTa Large', 'DEBATE Base (MB)', 'DEBATE Large (MB)', 'DEBATE Base', 'DEBATE Large']
task_order = ['ANLI', 'MNLI', 'WANLI', 'Overall']
nli_formatted = format_f1_table(
    combined,
    index_col='Model',
    task_col='Task',
    value_col='F1',
    se_col='F1_SE',
    task_order=task_order,
    model_order=model_order,
)

nli_string = nli_formatted.to_string()

output_file = './tables/table_18.txt'
with open(output_file, 'w') as f:
    f.write(nli_string)
    
############
## Table 19
############
# Out-of-domain benchmark, with columns renamed to what metrics_with_errors expects
ood = pd.read_csv('./data/out_domain_bench.csv').rename(
    columns={'task_name': 'task', 'labels': 'entailment'}
)

# Same model list (`columns`) as Table 18; overall scores first, then per task
overall = metrics_with_errors(ood, columns, group_by=None).reset_index()
overall['Task'] = 'overall'
task = metrics_with_errors(ood, columns, group_by='task').reset_index()

# Display names for both model identifiers and task identifiers
ood_display_names = {
    'base_nli': 'DeBERTa Base',
    'large_nli': 'DeBERTa Large',
    'base_debate': 'DEBATE Base',
    'large_debate': 'DEBATE Large',
    'base_modern': 'DEBATE Base (MB)',
    'large_modern': 'DEBATE Large (MB)',
    'tweet_topic': 'Tweet Topics',
    'yahootopics': 'Yahoo Topics',
    'amazonpolarity': 'Amazon Sentiment',
    'rottentomatoes': 'Movie Sentiment',
    'agnews': 'AG News',
    'emotiondair': 'DAIR AI Emotions',
    'sst2': 'Stanford Sentiment 2',
    'go_emotions': 'Google Emotions',
    'overall': 'Overall',
}
combined = (
    pd.concat([overall, task])
    .rename(columns={'Column': 'Model'})
    .replace(ood_display_names)
)

# Wide table: one row per model, F1 / F1_SE pairs per task
model_order = ['DeBERTa Base', 'DeBERTa Large', 'DEBATE Base', 'DEBATE Large', 'DEBATE Base (MB)', 'DEBATE Large (MB)']
task_order = [
    'AG News', 'Tweet Topics', 'Yahoo Topics', 'Amazon Sentiment', 'Movie Sentiment',
    'DAIR AI Emotions', 'Google Emotions', 'Stanford Sentiment 2', 'Overall',
]
ood_formatted = format_f1_table(
    combined,
    index_col='Model',
    task_col='Task',
    value_col='F1',
    se_col='F1_SE',
    task_order=task_order,
    model_order=model_order,
)

ood_string = ood_formatted.to_string()

output_file = './tables/table_19.txt'
with open(output_file, 'w') as f:
    f.write(ood_string)
    
############
## Figure 10
############
# Hypothesis stability: label-change rate per (Model, Checkpoint), drawn as one line per model
stab_labs = pd.read_csv('./data/stability_labels.csv')
stab = compute_stability(
    stab_labs,
    group_by=['Model', 'Checkpoint'],
)
plot_stability(
    stab,
    title='Hypothesis Stability',
    save_path='./figures/figure_10.png',
)

############
## Figure 11
############
alts = pd.read_csv('./data/alt_hypotheses.csv')
pred_cols = ['base_hyp', 'large_hyp', 'base_final', 'large_final']
metrics_df = metrics_with_errors(alts, preds=pred_cols, n_bootstrap=1000)

# Bars are grouped by model (x ticks: DEBATE Base, DEBATE Large) and coloured by
# labelling strategy (legend). Each series therefore holds one value per model,
# in tick order. Rows and columns are selected by label, not position.
hyp_cols = ['base_hyp', 'large_hyp']        # Original Hypothesis
final_cols = ['base_final', 'large_final']  # Majority Classification
hyp_bars = metrics_df.loc[hyp_cols, 'F1']
final_bars = metrics_df.loc[final_cols, 'F1']
hyp_errors = metrics_df.loc[hyp_cols, 'F1_SE']
final_errors = metrics_df.loc[final_cols, 'F1_SE']

# number of bar groups (one per model)
n = len(hyp_cols)
group_pos = np.arange(n)
width = 0.25

blue = sns.color_palette("colorblind")[0]
orange = sns.color_palette("colorblind")[1]

plt.figure(figsize=(8, 5))

plt.bar(group_pos, hyp_bars, width, label='Original Hypothesis',
        yerr=hyp_errors, capsize=10, error_kw={'elinewidth': 1}, color=blue)
plt.bar(group_pos + width, final_bars, width, label='Majority Classification',
        yerr=final_errors, capsize=10, error_kw={'elinewidth': 1}, color=orange)

plt.grid(axis='y', linestyle='--', alpha=0.7)
sns.despine(top=True, right=True)

plt.ylabel('F1', fontweight='bold', size=14)

# centre each tick between its two bars
plt.xticks(group_pos + width / 2, ('DEBATE Base', 'DEBATE Large'), fontweight='bold', fontsize=14)
plt.yticks(ticks=[0, .2, .4, .6, .8, 1], labels=['0%', '20%', '40%', '60%', '80%', '100%'])

# headroom above 100% leaves space for the legend in the top right
plt.ylim(0, 1.1)
plt.legend(loc='upper right', prop={'size': 8.5})
plt.savefig(r'./figures/figure_11.png', dpi=300)
plt.close()

############
## Figure 12
############
fs = pd.read_csv('./data/fewshot_overfit_res.csv')
fs.replace({'DEBATE MB Large': 'DEBATE Large (MB)',
            'DEBATE MB Base': 'DEBATE Base (MB)'}, inplace=True)

# NOTE(review): columns are renamed by position, which assumes the CSV stores
# them in the order MCC, F1, Accuracy, Model — confirm against the data file.
fs.columns = ['MCC', 'F1', 'Accuracy', 'Model']

data = fs
modelorder = ["DEBATE Large", "DEBATE Large (MB)", "DEBATE Base", "DEBATE Base (MB)"]

# Ordered categorical fixes the row order of the boxes; models not listed become NaN
data['Model'] = pd.Categorical(data['Model'], categories=modelorder, ordered=True)
bar_color = sns.color_palette("colorblind")[0]

plt.figure(figsize=(14, 8))

# seaborn's `fill` parameter is a boolean (solid vs. line-art boxes); the box
# colour has to be passed via `color`.
sns.boxplot(
    x='F1',
    y='Model',
    data=data,
    order=modelorder,
    color=bar_color,
)

# One highlighted point per box, at y = 0..3 in `modelorder` order.
# NOTE(review): these F1 values are hard-coded; their source is not shown in this script — confirm.
x_positions = [0.961, 0.950, 0.946, 0.931]
y_positions = range(len(modelorder))
plt.scatter(x_positions, y_positions, color=sns.color_palette("colorblind")[1], s=100, zorder=3)

sns.despine(top=True, right=True)
plt.xlabel('F1', fontweight='bold', size=24)
plt.ylabel('')
plt.yticks(fontsize=24, fontweight='bold')
plt.xticks(ticks=[.8, .9, 1], labels=['80%', '90%', '100%'])
plt.grid(axis='x', which='both', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(r'./figures/figure_12.png', dpi=300)
plt.close()